Используются три типа задач (Romo, DM и CtxDM, каждая в двух вариантах — всего шесть задач).
Сеть состоит из LIF AdEx-нейронов.
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from cgtasknet.net.lifadex import SNNlifadex
from cgtasknet.tasks.reduce import (
CtxDMTaskParameters,
DMTaskParameters,
DMTaskRandomModParameters,
MultyReduceTasks,
RomoTaskParameters,
RomoTaskRandomModParameters,
)
from norse.torch.functional.lif_adex import LIFAdExParameters
from tqdm import tqdm
# Run on the first CUDA GPU when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"{device=}")
device=device(type='cuda', index=0)
import os
def plot_results(
    inputs,
    target_outputs,
    outputs,
    *,
    max_batches=20,
    task_names=None,
    figure_name=None,
    n_sensory_inputs=3,
):
    """Plot inputs, task code, target and actual outputs for the first trials.

    For each of the first ``min(n_trials, max_batches)`` trials one figure with
    four panels is drawn, saved as
    ``figures/network_outputs_<figure_name>_batch_<k>.pdf`` and shown.
    NOTE: repeated calls with the same ``figure_name`` overwrite earlier files.

    Args:
        inputs: Network inputs indexed ``[time, trial, channel]``. The first
            ``n_sensory_inputs`` channels are plotted as time series; the
            remaining channels are drawn as task-code (context) bars using
            their value at time step 0.
        target_outputs: Target outputs indexed ``[time, trial, output]``.
        outputs: Network outputs indexed ``[time, trial, output]``.
        max_batches: Maximum number of trials to plot.
        task_names: Tick labels for the task-code panel (sorted before use);
            defaults to the module-level ``tasks`` list.
        figure_name: Tag used in the saved file names; defaults to the
            module-level ``name``.
        n_sensory_inputs: Number of leading input channels plotted as signals.

    Each of the three arrays may be a ``torch.Tensor`` (on any device) or a
    NumPy array.
    """

    def _to_numpy(array):
        # Tensors may live on the GPU and carry autograd history.
        if isinstance(array, torch.Tensor):
            return array.detach().cpu().numpy()
        return np.asarray(array)

    # Convert once up front instead of once per trial/channel; converting each
    # argument independently also avoids the NameError the old code hit when
    # NumPy arrays were passed.
    inputs_np = _to_numpy(inputs)
    targets_np = _to_numpy(target_outputs)
    outputs_np = _to_numpy(outputs)
    if task_names is None:
        task_names = tasks
    if figure_name is None:
        figure_name = name

    os.makedirs("figures", exist_ok=True)
    # Bound by the real batch dimension rather than the global batch_size.
    n_trials = min(inputs_np.shape[1], max_batches)
    for trial in range(n_trials):
        fig, (ax_in, ax_ctx, ax_target, ax_out) = plt.subplots(
            1, 4, figsize=(15, 3)
        )

        ax_in.set_title("Inputs")
        ax_in.set_xlabel("$time, ms$")
        ax_in.set_ylabel("$Magnitude$")
        for ch in range(n_sensory_inputs):
            ax_in.plot(inputs_np[:, trial, ch], label=rf"$in_{ch + 1}$")
        ax_in.legend()

        ax_ctx.set_title("Task code (context)")
        ax_ctx.set_xticks(np.arange(1, len(task_names) + 1))
        ax_ctx.set_xticklabels(sorted(task_names), rotation=90)
        ax_ctx.set_yticks([])
        for ch in range(n_sensory_inputs, inputs_np.shape[-1]):
            # One vertical bar per context channel at x = 1, 2, ...
            x = ch - n_sensory_inputs + 1
            ax_ctx.plot([x, x], [0, inputs_np[0, trial, ch]])

        ax_target.set_title("Target output")
        ax_target.set_xlabel("$time, ms$")
        for k in range(targets_np.shape[-1]):
            ax_target.plot(targets_np[:, trial, k], label=rf"$out_{k + 1}$")
        ax_target.legend()

        ax_out.set_title("Real output")
        ax_out.set_xlabel("$time, ms$")
        for k in range(outputs_np.shape[-1]):
            ax_out.plot(outputs_np[:, trial, k], label=rf"$out_{k + 1}$")
        ax_out.legend()

        fig.tight_layout()
        fig.savefig(
            os.path.join(
                "figures", f"network_outputs_{figure_name}_batch_{trial}.pdf"
            )
        )
        plt.show()
        plt.close(fig)
# Training hyper-parameters.
batch_size = 100
number_of_epochs = 2000
number_of_tasks = 1

# Timing parameters for the Romo tasks (time units presumably seconds, given
# dt=0.001 in the printed parameters — TODO confirm).
romo_parameters = RomoTaskRandomModParameters(
    romo=RomoTaskParameters(
        trial_time=0.1,
        positive_shift_trial_time=0.2,
        delay=0.1,
        positive_shift_delay_time=1.4,
    ),
)
# Timing parameters for the DM tasks.
dm_parameters = DMTaskRandomModParameters(
    dm=DMTaskParameters(trial_time=0.1, positive_shift_trial_time=0.8)
)
# The context-DM tasks reuse the DM timing parameters.
ctx_parameters = CtxDMTaskParameters(dm=dm_parameters.dm)

# Std of the Gaussian noise added to the inputs during training.
sigma = 0.1

# Each task family appears in two variants that share one parameter object.
tasks = ["RomoTask1", "RomoTask2", "DMTask1", "DMTask2", "CtxDMTask1", "CtxDMTask2"]
task_dict = dict(
    zip(tasks, [romo_parameters] * 2 + [dm_parameters] * 2 + [ctx_parameters] * 2)
)
Task = MultyReduceTasks(
    tasks=task_dict,
    batch_size=batch_size,
    delay_between=0,
    enable_fixation_delay=True,
)
# Echo the configuration of every task and the network's I/O sizes.
print("Task parameters:")
for task_name, task_params in task_dict.items():
    print(f"{task_name}:\n{task_params}\n")
print("inputs/outputs: {}/{}".format(*Task.feature_and_act_size))
Task parameters: RomoTask1: RomoTaskRandomModParameters(romo=RomoTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=(None, None), delay=0.1, negative_shift_trial_time=0, positive_shift_trial_time=0.2, negative_shift_delay_time=0, positive_shift_delay_time=1.4), n_mods=2) RomoTask2: RomoTaskRandomModParameters(romo=RomoTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=(None, None), delay=0.1, negative_shift_trial_time=0, positive_shift_trial_time=0.2, negative_shift_delay_time=0, positive_shift_delay_time=1.4), n_mods=2) DMTask1: DMTaskRandomModParameters(dm=DMTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=None, negative_shift_trial_time=0, positive_shift_trial_time=0.8), n_mods=2) DMTask2: DMTaskRandomModParameters(dm=DMTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=None, negative_shift_trial_time=0, positive_shift_trial_time=0.8), n_mods=2) CtxDMTask1: CtxDMTaskParameters(dm=DMTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=None, negative_shift_trial_time=0, positive_shift_trial_time=0.8), context=None, value=(None, None)) CtxDMTask2: CtxDMTaskParameters(dm=DMTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=None, negative_shift_trial_time=0, positive_shift_trial_time=0.8), context=None, value=(None, None)) inputs/outputs: 9/3
# Visual sanity check of one generated batch (up to 10 trials) before training.
# Channels 0-2 are drawn in the "Inputs" panel, the remaining channels in the
# "Task code (context)" panel, and the targets in the third panel.
inputs, t_outputs = Task.dataset(n_trials=1)
for trial_idx in range(min(batch_size, 10)):
    fig, (ax_in, ax_ctx, ax_tgt) = plt.subplots(1, 3, figsize=(15, 3))
    ax_in.set_ylabel("$Magnitude$")
    for ax, title in zip(
        (ax_in, ax_ctx, ax_tgt), ("Inputs", "Task code (context)", "Target output")
    ):
        ax.set_title(title)
        ax.set_xlabel("$time, ms$")

    for ch in range(inputs.shape[-1]):
        panel = ax_in if ch < 3 else ax_ctx
        panel.plot(inputs[:, trial_idx, ch], label=rf"$in_{ch + 1}$")
    for ch in range(t_outputs.shape[-1]):
        ax_tgt.plot(t_outputs[:, trial_idx, ch], label=rf"$out_{ch + 1}$")

    for ax in (ax_in, ax_ctx, ax_tgt):
        ax.legend()
    fig.tight_layout()
    plt.show()
    plt.close()
del inputs
del t_outputs
# Input/output sizes are dictated by the task generator.
feature_size, output_size = Task.feature_and_act_size
hidden_size = 450

# LIF-AdEx neuron parameters; each hidden neuron gets its own tau_ada_inv drawn
# uniformly from [0.5, 6) (torch.rand is uniform on [0, 1)).
neuron_parameters = LIFAdExParameters(
    v_th=torch.as_tensor(0.65),
    tau_ada_inv=torch.rand(hidden_size).to(device) * (6 - 0.5) + 0.5,
    alpha=100,
    method="super",
    # rho_reset = torch.as_tensor(5)
)

network = SNNlifadex(
    feature_size,
    hidden_size,
    output_size,
    neuron_parameters=neuron_parameters,
    tau_filter_inv=500,
)
model = network.to(device)
learning_rate = 0.01
class RMSELoss(nn.Module):
    """Root-mean-square error loss: ``sqrt(mean((yhat - y) ** 2) + eps)``.

    Args:
        eps: Non-negative constant added under the square root. The default
            ``0.0`` gives plain RMSE (identical to the previous behaviour). A
            small positive value (e.g. ``1e-8``) keeps the gradient finite
            when predictions match the targets exactly, because the
            derivative of ``sqrt`` diverges at 0 and would otherwise produce
            NaN gradients.
    """

    def __init__(self, eps=0.0):
        super().__init__()
        self.mse = nn.MSELoss()
        self.eps = eps

    def forward(self, yhat, y):
        """Return the RMSE between prediction ``yhat`` and target ``y``."""
        return torch.sqrt(self.mse(yhat, y) + self.eps)
# Training objective is plain MSE (the RMSELoss class is defined but not used).
criterion = nn.MSELoss()
# Adam over all model parameters.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
Если памяти недостаточно, данные нужно генерировать заново на каждой эпохе внутри основного цикла обучения (предварительная генерация всего набора ниже отключена).
# Optional: pre-generate one batch per epoch up front instead of generating
# data inside the training loop. Disabled by default because of memory use.
PRECOMPUTE_DATASET = False
if PRECOMPUTE_DATASET:
    list_inputs = []
    list_t_outputs = []
    for i in tqdm(range(number_of_epochs)):
        temp_input, temp_t_output = Task.dataset()
        # Noise only on the first three input channels, as in the training
        # loop; the task-code channels stay clean.
        temp_input[:, :, :3] += np.random.normal(
            0, sigma, size=temp_input[:, :, :3].shape
        )
        # astype returns a NEW array; previously the result was discarded,
        # so the float16 down-cast (meant to save memory) never happened.
        list_inputs.append(temp_input.astype(np.float16))
        list_t_outputs.append(temp_t_output.astype(np.float16))
from cgtasknet.instruments.instrument_accuracy_network import correct_answer
from cgtasknet.net.states import LIFAdExRefracInitState
# Run identifier: used as the checkpoint file name and as the figure-name tag.
name = (
    "Train_dm_and_romo_task_reduce_lif_adex_without_refrac_random_delay_long_a"
    f"_alpha_{neuron_parameters.alpha}_N_{hidden_size}"
)
# NOTE(review): init_state is created here but never passed to the model in
# the training loop — confirm whether it is still needed.
init_state = LIFAdExRefracInitState(batch_size, hidden_size, device=device)
running_loss = 0
def _evaluate_training_accuracy(n_rounds=10, noise_std=0.01):
    """Return the accuracy (in %) of ``model`` over ``n_rounds`` fresh batches.

    Gaussian noise with std ``noise_std`` is added only to the first three
    input channels. These are the same channels that receive noise during
    training, so the task-code channels stay clean. Runs under
    ``torch.no_grad()`` because no gradients are needed here.
    Assumes each ``Task.dataset(1, delay_between=0)`` call yields
    ``batch_size`` trials (TODO confirm).
    """
    n_correct = 0
    with torch.no_grad():
        for _ in range(n_rounds):
            eval_inputs, eval_targets = Task.dataset(1, delay_between=0)
            eval_inputs[:, :, :3] += np.random.normal(
                0, noise_std, size=eval_inputs[:, :, :3].shape
            )
            eval_inputs = torch.from_numpy(eval_inputs).type(torch.float).to(device)
            eval_targets = torch.from_numpy(eval_targets).type(torch.float).to(device)
            eval_outputs = model(eval_inputs)[0]
            answers = correct_answer(
                eval_outputs[:, :, 1:], eval_targets[:, :, 1:], eval_targets[:, :, 0]
            )
            n_correct += torch.sum(answers).item()
    return n_correct / batch_size / n_rounds * 100


# One optimisation step per freshly generated batch. Every 10 steps:
# log the mean loss, checkpoint the weights to `name`, and log the accuracy.
for i in tqdm(range(number_of_epochs)):
    inputs, target_outputs = Task.dataset()
    inputs[:, :, :3] += np.random.normal(0, sigma, size=inputs[:, :, :3].shape)
    inputs = torch.from_numpy(inputs).type(torch.float).to(device)
    target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)

    optimizer.zero_grad()
    # forward + backward + optimize
    outputs, _ = model(inputs)
    loss = criterion(outputs, target_outputs)
    loss.backward()
    optimizer.step()
    running_loss += loss.item()

    # Release this batch's tensors before evaluation allocates new ones.
    del inputs, target_outputs, outputs
    torch.cuda.empty_cache()

    if i % 10 == 9:
        with open("log_multy.txt", "a") as f:
            f.write("epoch: {:d} loss: {:0.5f}\n".format(i + 1, running_loss / 10))
        running_loss = 0.0
        torch.save(model.state_dict(), name)

        accuracy = _evaluate_training_accuracy()
        with open("accuracy_multy.txt", "a") as f:
            f.write(f"epoch = {i}; correct/all = {accuracy}\n")
        torch.cuda.empty_cache()
print("Finished Training")
100%|██████████| 2000/2000 [3:23:58<00:00, 6.12s/it]
Finished Training
# Final evaluation: accuracy over 100 fresh batches with Gaussian noise of
# std 0.01 on the first three input channels (task-code channels stay clean).
# NOTE(review): plot_results writes the same file names for every noise
# level, so each evaluation's figures overwrite the previous ones.
eval_noise_std = 0.01
n_eval_rounds = 100
result = 0
# No gradients are needed; without no_grad each forward pass would keep its
# whole autograd graph alive until `outputs` is rebound.
with torch.no_grad():
    for _ in tqdm(range(n_eval_rounds)):
        inputs, target_outputs = Task.dataset(1, delay_between=0)
        inputs[:, :, :3] += np.random.normal(
            0, eval_noise_std, size=inputs[:, :, :3].shape
        )
        inputs = torch.from_numpy(inputs).type(torch.float).to(device)
        target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
        outputs = model(inputs)[0]
        answers = correct_answer(
            outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
        )
        result += torch.sum(answers).item()
# Assumes each Task.dataset(1, ...) call yields `batch_size` trials.
accuracy = result / batch_size / n_eval_rounds * 100
print(accuracy)
# Figures show the last evaluated batch.
plot_results(inputs, target_outputs, outputs)
del inputs, target_outputs, outputs
torch.cuda.empty_cache()
100%|██████████| 100/100 [03:18<00:00, 1.98s/it]
93.62
# Final evaluation: accuracy over 100 fresh batches with Gaussian noise of
# std 0.05 on the first three input channels (task-code channels stay clean).
# NOTE(review): plot_results writes the same file names for every noise
# level, so each evaluation's figures overwrite the previous ones.
eval_noise_std = 0.05
n_eval_rounds = 100
result = 0
# No gradients are needed; without no_grad each forward pass would keep its
# whole autograd graph alive until `outputs` is rebound.
with torch.no_grad():
    for _ in tqdm(range(n_eval_rounds)):
        inputs, target_outputs = Task.dataset(1, delay_between=0)
        inputs[:, :, :3] += np.random.normal(
            0, eval_noise_std, size=inputs[:, :, :3].shape
        )
        inputs = torch.from_numpy(inputs).type(torch.float).to(device)
        target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
        outputs = model(inputs)[0]
        answers = correct_answer(
            outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
        )
        result += torch.sum(answers).item()
# Assumes each Task.dataset(1, ...) call yields `batch_size` trials.
accuracy = result / batch_size / n_eval_rounds * 100
print(accuracy)
# Figures show the last evaluated batch.
plot_results(inputs, target_outputs, outputs)
del inputs, target_outputs, outputs
torch.cuda.empty_cache()
100%|██████████| 100/100 [03:14<00:00, 1.95s/it]
93.79
# Final evaluation: accuracy over 100 fresh batches with Gaussian noise of
# std 0.1 on the first three input channels (task-code channels stay clean).
# NOTE(review): plot_results writes the same file names for every noise
# level, so each evaluation's figures overwrite the previous ones.
eval_noise_std = 0.1
n_eval_rounds = 100
result = 0
# No gradients are needed; without no_grad each forward pass would keep its
# whole autograd graph alive until `outputs` is rebound.
with torch.no_grad():
    for _ in tqdm(range(n_eval_rounds)):
        inputs, target_outputs = Task.dataset(1, delay_between=0)
        inputs[:, :, :3] += np.random.normal(
            0, eval_noise_std, size=inputs[:, :, :3].shape
        )
        inputs = torch.from_numpy(inputs).type(torch.float).to(device)
        target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
        outputs = model(inputs)[0]
        answers = correct_answer(
            outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
        )
        result += torch.sum(answers).item()
# Assumes each Task.dataset(1, ...) call yields `batch_size` trials.
accuracy = result / batch_size / n_eval_rounds * 100
print(accuracy)
# Figures show the last evaluated batch.
plot_results(inputs, target_outputs, outputs)
del inputs, target_outputs, outputs
torch.cuda.empty_cache()
100%|██████████| 100/100 [03:17<00:00, 1.97s/it]
93.27
# Final evaluation: accuracy over 100 fresh batches with Gaussian noise of
# std 0.5 on the first three input channels (task-code channels stay clean).
# NOTE(review): plot_results writes the same file names for every noise
# level, so each evaluation's figures overwrite the previous ones.
eval_noise_std = 0.5
n_eval_rounds = 100
result = 0
# No gradients are needed; without no_grad each forward pass would keep its
# whole autograd graph alive until `outputs` is rebound.
with torch.no_grad():
    for _ in tqdm(range(n_eval_rounds)):
        inputs, target_outputs = Task.dataset(1, delay_between=0)
        inputs[:, :, :3] += np.random.normal(
            0, eval_noise_std, size=inputs[:, :, :3].shape
        )
        inputs = torch.from_numpy(inputs).type(torch.float).to(device)
        target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
        outputs = model(inputs)[0]
        answers = correct_answer(
            outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
        )
        result += torch.sum(answers).item()
# Assumes each Task.dataset(1, ...) call yields `batch_size` trials.
accuracy = result / batch_size / n_eval_rounds * 100
print(accuracy)
# Figures show the last evaluated batch.
plot_results(inputs, target_outputs, outputs)
del inputs, target_outputs, outputs
torch.cuda.empty_cache()
100%|██████████| 100/100 [03:16<00:00, 1.97s/it]
82.98
# Rebind the evaluation names to a placeholder.
inputs = outputs = 0
# Persist the per-neuron tau_ada_inv values (np.save appends ".npy").
# NOTE(review): this reads the tensor stored in neuron_parameters; if the
# model keeps its own trainable copy, these may be the initial values — confirm.
np.save(
    f"tau_ada_inv_alpha={neuron_parameters.alpha}",
    tau_ada_inv_distrib := neuron_parameters.tau_ada_inv.cpu().numpy(),
)
# Plot the accuracy curve logged during training. Each log line looks like
# "<label> = <epoch>; correct/all = <accuracy>". The log is opened in append
# mode by the training loop, so it may contain several runs. Read the epoch from
# each line instead of assuming exactly one entry every 10 epochs of a single
# 2000-epoch run; a length mismatch between x and y made plt.plot raise.
logged_epochs = []
lines = []  # accuracies, in log order
with open("accuracy_multy.txt", "r") as f:
    for line in f:
        if not line.strip():
            continue
        fields = line.split("=")
        logged_epochs.append(int(fields[1].split(";")[0]))
        lines.append(float(fields[2].strip()))
plt.figure(figsize=(8, 5))
plt.plot(logged_epochs, lines, ".", linestyle="--", markersize=5)
plt.ylabel(r"Accuracy%")
plt.xlabel(r"Epochs")
Text(0.5, 0, 'Epochs')